#!/usr/bin/env python
"""
Extract and store data entries from a set of reStructuredText documents.

Entries look like this (all names are generic)::

  [book]
    :title: Probability Theory
    :subtitle: The Logic of Science
    :authors: E.T. Jaynes, G. Larry Bretthorst
    :isbn: 978-0521592710

You can have the data stored in an existing database, or in a file. This program
can also automatically infer a database schema from the data (and create the
tables). If you only store the data in a database, only the fields defined in
the schema will be stored (the others will be dropped). This allows you to
create your database tables ahead of time, as you like, and to have only valid
fields filled with the data later.

This script provides an ultra-simple way to store data that can be expressed as
name-value pairs embedded in a text file (with other text). The typical kind of
data that you would use this for would be for PIM data (e.g., books, addresses,
links, etc.).

A special source key is used in the outgoing data to identify the source file
that each data entry came from. By default, the source is the filename, but you
can set this source id in the file itself, by creating an 'Id' docinfo entry, at
the top of the document, like this::

  ========
  My Title
  ========
  :Id: <unique-id>
  ...

About the outputs:

  If <OUTPUT> is a filename, SQL code or CSV is written to a file. However, if a
  connection string is provided, we run the commands on the given database.
  Warning: if an inferred schema definition is to be stored in a database, we
  first drop the data from those tables.

About normalization:

  By default the field names are sanitized by attempting to remove plurals
  automatically, and the value strings are merged into a single line each. This
  can be optionally disabled.

Notes:

  The ideas for this project originate from project Nabu
  (http://furius.ca/nabu/). It is an attempt at providing its basic
  functionality without all the complications, setup and customization that Nabu
  requires.

"""
__author__ = 'Martin Blais <blais@furius.ca>'

# stdlib imports
import sys, os, re, getpass, logging
from types import ClassType
from StringIO import StringIO
from collections import defaultdict
from operator import itemgetter, attrgetter
from os.path import *

# psycopg2 imports
import psycopg2 as dbapi

# docutils imports
from docutils import core, nodes, io, utils


# Name of the special key/column that records which source document each data
# entry came from.
col_source = '__source__'


#-------------------------------------------------------------------------------
# Operations on association lists.

def alist_memq(alist, key):
    """Return True if 'key' is a key of the assoc-list 'alist', else False.

    'alist' is a sequence of (key, value) pairs; keys are compared with '=='.
    """
    # Always return a real boolean rather than falling off the end with None.
    return any(k == key for k, _ in alist)
        
def alist_replace(alist, from_, to_):
    "Destructively rename every 'from_' key of the assoc-list 'alist' to 'to_'."
    for pos, (k, v) in enumerate(alist):
        if k == from_:
            # Only the pair at the current position is swapped, so the
            # iteration over the list is not disturbed.
            alist[pos] = (to_, v)

def alist_remove(alist, key):
    "Delete, in place, every pair of the assoc-list 'alist' whose key is 'key'."
    # Slice assignment keeps the same list object, so callers see the change.
    alist[:] = [pair for pair in alist if not pair[0] == key]


#-------------------------------------------------------------------------------
# Parsing and values extraction code

# Note: this is generic, perhaps should be part of docutils.nodes.
def find_first(nodetype, node, document):
    """ Find the first node under 'node' that is of type 'nodetype'.

    'nodetype' is a node class (e.g. nodes.docinfo), 'node' is the root of the
    subtree to search and 'document' is the document given to the visitor.
    Returns the first matching node met during the walk, or None if there is
    no such node.
    """
    import inspect

    # Accept new-style classes as well as old-style ones; checking against
    # ClassType alone rejects every new-style node class.
    assert inspect.isclass(nodetype), nodetype
    found = []

    class FindFirst(nodes.SparseNodeVisitor):
        def visit(self, node):
            found.append(node)
            # The return value of a visit method is ignored by walk(); raising
            # StopTraversal is what actually ends the traversal.
            raise nodes.StopTraversal

    # The visitor dispatches on 'visit_<classname>', so register our method
    # under the name that matches the requested node class.
    setattr(FindFirst, 'visit_%s' % nodetype.__name__,
            FindFirst.visit)

    vis = FindFirst(document)
    node.walk(vis)

    return found[0] if found else None

def get_file_entries(fn):
    """ Parse the file 'fn' and extract the data entries from it.

    Returns a (docid, entries) pair. 'docid' is the value of the document's
    'Id' docinfo field when there is one, otherwise the basename of 'fn'.
    'entries' is the list of Entry objects found in the document.
    """
    source = open(fn)
    try:
        document = core.publish_doctree(
            source,
            source_class=io.FileInput,
            reader_name='standalone',
            parser_name='restructuredtext',
            settings_overrides={'report_level': 'error'},
            )
    finally:
        # Release the handle even if parsing fails; closing a file that has
        # already been closed is harmless.
        source.close()

    # Extract the unique document id.
    docid = None
    docinfo = find_first(nodes.docinfo, document, document)
    if docinfo:
        fields = extract_fields(docinfo, document)
        dfields = dict((k.lower(), v) for k, v in fields)
        docid = dfields.get(u'id', None)
    if docid is None:
        docid = basename(fn)

    # Obtain all the data from the document.
    v = FindData(fn, docid, document)
    document.walk(v)
    return docid, v.entries


class Entry(object):
    """ One data entry extracted from a document.

    'source' identifies the document the entry came from, 'table' is the name
    of the table it belongs to, 'values' is the list of (key, value) pairs of
    the entry and 'locator' tells where the entry was found.
    """

    def __init__(self, source, table, values, locator):
        self.source, self.table = source, table
        self.values, self.locator = values, locator

    def __str__(self):
        # A header line followed by one line per (key, value) pair.
        header = '[%s]  (%s)\n' % (self.table, self.source)
        body = ['  :%s: %s\n' % pair for pair in self.values]
        return header + ''.join(body)

class FindData(nodes.SparseNodeVisitor):
    """ A visitor that finds all the definition_list_item which match our
    desired tagging format.

    A matching item has a term made of a name enclosed in brackets, braces or
    parentheses (e.g. '[book]') and a definition that consists of exactly one
    field_list. Each match is appended, as an Entry, to 'self.entries'.
    """

    # Regexp for the definition tag.
    tagre = re.compile(r'[\[\{\(]([a-zA-Z0-9_]+)[\]\}\)]\s*$')

    def __init__(self, filename, docid, *args):
        nodes.SparseNodeVisitor.__init__(self, *args)
        self.filename = filename
        self.docid = docid
        self.entries = []

    def visit_term(self, node):
        "Record an entry if this term introduces one, and skip its subtree."
        if len(node.children) != 1:
            return
        mo = self.tagre.match(node.astext())
        if not mo:
            return

        table = str2table(mo.group(1))
        dlitem = node.parent
        if len(dlitem.children) != 2:
            return

        defn = dlitem.children[1]
        if not isinstance(defn, nodes.definition) or len(defn.children) != 1:
            return
        flist = defn.children[0]
        if not isinstance(flist, nodes.field_list):
            return

        fields = extract_fields(flist, self.document)

        source, line = utils.get_source_line(node)
        locator = '%s:%s' % (abspath(self.filename), line)

        if opts.enclosing_section:
            # 'fields' is a list of (key, value) pairs, so the keys have to be
            # inspected explicitly to detect a clash.
            if any(key == 'section' for key, _ in fields):
                logging.warning("Entry at %s already has a 'section' field"
                                % locator)
            section = ''
            node_sec = find_enclosing_section(node)
            if node_sec is not None:
                i = node_sec.first_child_matching_class(nodes.title)
                # No index is returned when the section has no title child.
                if i is not None:
                    section = node_sec.children[i].astext()
            fields.append( ('section', section) )

        e = Entry(self.docid, table, fields, locator)
        self.entries.append(e)

        raise nodes.SkipNode()

def find_enclosing_section(node):
    """ Walk up the parents of 'node' (starting with 'node' itself) and return
    the first section node met, or None if there is none. """
    candidate = node
    while candidate:
        if isinstance(candidate, nodes.section):
            return candidate
        candidate = candidate.parent
    return None
    
def str2table(s):
    "Return 's' as a table name: lower-cased, with spaces turned into '_'."
    lowered = s.lower()
    return '_'.join(lowered.split(' '))

def extract_fields(node, document):
    "Collect the (key, value) pairs of every field found under 'node'."
    collector = ExtractFields(document)
    node.walk(collector)
    # Hand back a plain list rather than the visitor object itself.
    pairs = []
    pairs.extend(collector)
    return pairs

class ExtractFields(nodes.SparseNodeVisitor, list):
    """ A visitor that gathers field name/value pairs; the visitor is itself
    the list of (key, value) pairs collected so far. """

    def visit_field_name(self, node):
        name = node.astext()
        if opts.raw:
            self.key = name
        else:
            # Use just the first word of the field, sanitized.
            self.key = sanitize_column(name.split()[0])

    def visit_field_body(self, node):
        text = node.astext()
        if not opts.raw:
            text = sanitize_value(text)
        self.append( (self.key, text) )
        # The name has been consumed by this body.
        self.key = None


def sanitize_column(colname):
    "Return a column name usable in SQL: trimmed, lower-cased, a-z or '_' only."
    cleaned = colname.strip().lower()
    not_a_letter = re.compile('[^a-z]')
    return not_a_letter.sub('_', cleaned)

def sanitize_value(value):
    "Merge a value into a single line: strip each line, join them with spaces."
    return ' '.join(line.strip() for line in value.splitlines())

def uniquify_keys(entries):
    """ Rename, in place, the repeated field keys of each entry so that every
    key is unique within its entry. A repeated key 'k' becomes 'k_2', 'k_3',
    and so on (the first free suffix is used)."""
    for entry in entries:
        taken = set()
        renamed = []
        for pos, (key, value) in enumerate(entry.values):
            unique = key
            if key in taken:
                suffix = 2
                unique = '%s_%d' % (key, suffix)
                while unique in taken:
                    suffix += 1
                    if suffix >= 1000:
                        raise RuntimeError("Internal error in uniquify_keys.")
                    unique = '%s_%d' % (key, suffix)
                renamed.append( (pos, (unique, value)) )
            taken.add(unique)
        # Apply the renamings only once the whole entry has been scanned.
        for pos, pair in renamed:
            entry.values[pos] = pair

#-------------------------------------------------------------------------------
# Table definition inference code.

# Note: this is generic utils code.
def seq2dict(seq, classify_fun):
    """ Given a sequence of objects and a function to classify them, return a
    dict of (key, sublist of objects) whereby 'key' is computed by calling
    'classify_fun' on objects.

    'seq' must be a list or a tuple. Objects that cannot be classified (the
    classification raises an exception) are left out of the result and a
    warning is logged for each of them. The result is a defaultdict(list)."""
    assert isinstance(seq, (list, tuple)), seq
    r = defaultdict(list)
    for e in seq:
        try:
            r[classify_fun(e)].append(e)
        except Exception as exc:
            # Best effort: keep going, but do not hide the problem.
            logging.warning("Cannot classify %r, ignoring it: %s", e, exc)
    return r

def infer_tables(entries):
    """ Group the entries by table and infer a table description for each
    group. Returns a dict that maps table names to their description. """
    grouped = seq2dict(entries, attrgetter('table'))
    models = {}
    for table, table_entries in grouped.iteritems():
        models[table] = infer_table(table_entries)
    return models

# Patterns used to guess the type of a column from its values.
intre = re.compile(r'[0-9]+$')
# A decimal number: digits with at most one decimal point and at least one
# digit, so that values such as '1.2.3' or '.' are not taken for floats.
floatre = re.compile(r'(?:[0-9]+\.?[0-9]*|\.[0-9]+)$')

# A column needs more than this many values before a type is guessed for it.
min_nb_values_for_type = 5

def infer_table(entries):
    """ Given a list of entries from the same table, infer a table description.

    This returns a list of (column-name, column-type) pairs. The first pair is
    the source column; the others are the columns found in the entries, sorted
    by the average position at which they appear. A column is given the type
    int or float only if it has more than 'min_nb_values_for_type' values and
    all of them match the corresponding pattern; otherwise it is unicode.
    """

    coldata = defaultdict(list)
    positions = defaultdict(int)
    for e in entries:
        for i, (key, value) in enumerate(e.values):
            coldata[key].append(value)
            positions[key] += i

    coldefs = {}
    for colname, values in coldata.iteritems():
        ctype = unicode
        if len(values) > min_nb_values_for_type:
            if all(intre.match(x) for x in values):
                ctype = int
            elif all(floatre.match(x) for x in values):
                ctype = float
        coldefs[colname] = ctype

    def sortkey(item):
        # Use the mean position rather than the sum: a sum grows with the
        # number of entries that have the column, which would push frequent
        # columns after rare ones. The name breaks ties deterministically.
        colname = item[0]
        return (float(positions[colname]) / len(coldata[colname]), colname)

    return [(col_source, unicode)] + sorted(coldefs.items(), key=sortkey)

# SQL column type emitted by table2sql() for each inferred Python type.
sqltypes = {int: 'INTEGER',
            float: 'FLOAT',
            unicode: 'TEXT'}

def table2sql(table, tabledef):
    """Generate SQL table definition code given the table name and columns
    definition.

    'tabledef' is a sequence of (column-name, column-type) pairs, where the
    type is one of the keys of 'sqltypes'. Returns the text of a CREATE TABLE
    statement, with its lines joined (and ended) by os.linesep."""
    coldefs = ['  "%s" %s' % (colname, sqltypes[ctype])
               for colname, ctype in tabledef]

    lines = ['CREATE TABLE %s (' % table]
    if coldefs:
        # Commas go between the column definitions only. Joining them, rather
        # than stripping a trailing comma afterwards, keeps the opening
        # parenthesis intact when there is no column at all.
        lines.append((',' + os.linesep).join(coldefs))
    lines.append(');')
    lines.append('')
    return os.linesep.join(lines)


def sanitize_model(infmodel, dbmodel):
    """ Adjust the inferred model 'infmodel' so that it fits 'dbmodel' better
    ('dbmodel' may be None). When a table has both a singular and a plural
    form of a column name, one of them is removed from the model and a
    renaming is recorded. Returns the (modified) inferred model and the
    information needed later to sanitize the data entries accordingly."""

    plural_re = re.compile('(.*)s\s*$')

    column_renamings = []
    for table, columns in infmodel.iteritems():
        try:
            dbcolumns = dict(dbmodel[table])
        except (KeyError, TypeError):
            # No database model, or no such table in it.
            dbcolumns = {}

        known = dict(columns)
        for name in known:
            mo = plural_re.match(name)
            if not mo:
                continue
            sing = mo.group(1)
            if sing not in known:
                continue

            # Both forms are present: keep the plural only if the database
            # has the plural column and not the singular one.
            plur = '%ss' % sing
            keep_plural = plur in dbcolumns and sing not in dbcolumns
            if keep_plural:
                renaming = table, sing, plur
            else:
                renaming = table, plur, sing

            logging.warning("In table '%s', renaming %s to %s" % renaming)
            column_renamings.append(renaming)

            # The renamed column disappears from the model definition.
            alist_remove(columns, renaming[1])

    return infmodel, (column_renamings,)

def sanitize_data(entries_list, sani_info):
    """ Apply the sanitization changes required to make the data fit the model
    changes. This modifies the entries in-place and does not return anything.

    'sani_info' is a 1-tuple that holds a list of (table, from, to) column
    renamings; each renaming is applied to the entries of its table only."""
    (column_renamings,) = sani_info

    # Rename all the columns we need to.
    for table, from_, to_ in column_renamings:
        for e in entries_list:
            # A renaming is decided per table: the model of the other tables
            # is left alone, so their entries must keep their keys.
            if e.table != table:
                continue
            for i, (k, v) in enumerate(e.values):
                if k == from_:
                    e.values[i] = (to_, v)


#-------------------------------------------------------------------------------
# Database introspection.

def db_get_tables(conn):
    "Return the names of all the tables of the 'public' schema of the database."
    cursor = conn.cursor()
    query = """
      SELECT table_name FROM information_schema.tables
        WHERE table_schema = 'public';
        """
    cursor.execute(query)
    names = []
    for row in cursor:
        names.append(row[0])
    return names

def db_get_model(conn):
    "Return a dict that maps each table of the database to its columns."
    return dict((table, db_get_table_columns(conn, table))
                for table in db_get_tables(conn))

def db_get_table_columns(conn, table):
    "Return the (column name, data type) rows of a table of the database."
    cursor = conn.cursor()
    query = """
      SELECT column_name, data_type FROM information_schema.columns
        WHERE table_schema = 'public' AND
              table_name = %s
        """
    cursor.execute(query, (table,))
    rows = []
    rows.extend(cursor)
    return rows




#-------------------------------------------------------------------------------
# Filling up the database.

# A progress message is logged every time this many entries have been stored.
progress_count = 100

def _convert_value(dbtype, cvalue):
    """ Convert the value 'cvalue' to the Python type that fits the database
    column type 'dbtype'. Raises NotImplementedError for unsupported types."""
    dbtype = dbtype.lower()
    # The type names include the spellings reported by information_schema
    # (e.g. 'character varying', 'double precision').
    if dbtype in ('text', 'varchar', 'char', 'character varying', 'character'):
        return unicode(cvalue)
    elif dbtype in ('integer', 'smallint', 'bigint'):
        return int(cvalue)
    elif dbtype in ('numeric', 'float', 'real', 'double precision'):
        return float(cvalue)
    raise NotImplementedError("Unsupported type: '%s'" % dbtype)


def store_entries(entries_list, dbmodel, curs):
    """ Given a list of entries to be stored, try to store as much data as
    possible in the given database model.

    'dbmodel' maps table names to their (column name, data type) pairs and
    'curs' is the database cursor used to run the statements. The rows that
    were previously stored for the sources of the entries are deleted first.
    Entries whose table is not in the model are skipped, as are the fields
    that have no column in the model, that are repeated, or whose value cannot
    be converted to the type of the column."""

    logging.info("  ... dropping old data")

    # Drop the old data from this document from the previous tables.
    tables = set(x.table for x in entries_list)
    sources = set(x.source for x in entries_list)
    for table in tables:
        if table not in dbmodel:
            continue
        for source in sources:
            curs.execute("DELETE FROM %s WHERE %s = %%s" % (table, col_source),
                         (source,))

    dbmodel = dict((k, dict(v)) for (k,v) in dbmodel.iteritems())
    for i, e in enumerate(entries_list):
        if i > 0 and (i % progress_count) == 0:
            logging.info("  ... progress: %d/%d" % (i, len(entries_list)))

        try:
            dbcols = dbmodel[e.table]
        except KeyError:
            # Table not available: nothing of this entry can be stored. (Going
            # on would reuse the columns of the previous entry's table.)
            continue

        outcols = [col_source]
        outvalues = [e.source]
        colseen = set()
        for cname, cvalue in e.values:
            # If the columns cannot be stored in the current model, don't.
            if cname not in dbcols:
                continue

            if cname in colseen:
                logging.warning("Duplicate field name at %s" % e.locator)
                continue
            else:
                colseen.add(cname)

            try:
                value = _convert_value(dbcols[cname], cvalue)
            except ValueError:
                # One bad value should not prevent the rest from being stored.
                logging.warning("Invalid value for field '%s' at %s" %
                                (cname, e.locator))
                continue

            outcols.append(cname)
            outvalues.append(value)

        # The source column is always present, so there is always a row.
        curs.execute("""
          INSERT INTO %s (%s) VALUES (%s)
          """ % (e.table,
                 ', '.join('"%s"' % x for x in outcols),
                 ', '.join(['%s'] * len(outvalues))),
                     outvalues)


#-------------------------------------------------------------------------------
# Main program.

def parse_dburi(dburi):
    """ Parse the database connection URI and return a dict of the values that
    allow us to connect to a database.

    'dburi' is either '<scheme>://[user[:password]@]host/dbname', where the
    scheme is one of 'db', 'postgres' or 'postgresql', or a bare database
    name. The dict has the keys 'user', 'password', 'host' and 'database'.
    A missing user defaults to the current user, a missing host to
    'localhost', and a missing password is asked for interactively.

    Raises ValueError if the connection string is invalid."""

    user, passwd, host, dbname = [None] * 4
    mo = re.match(r'(db|postgres|postgresql)://(?:([^:@]+)'
                  r'(?::([^:@]+))?@)?([A-Za-z0-9_.\-]+)/([A-Za-z0-9_\-]+)/?$',
                  dburi)
    if mo:
        user, passwd, host, dbname = mo.group(2, 3, 4, 5)
    elif re.match(r'[A-Za-z0-9_.\-]+$', dburi):
        # A bare database name; the whole string must be a valid name, so
        # that a malformed URI is not mistaken for one.
        dbname = dburi
    else:
        # The option parser is local to main() and is not reachable here.
        raise ValueError("Invalid database connection string: %r" % dburi)

    if user is None:
        user = getpass.getuser()
    if passwd is None:
        passwd = getpass.getpass()
    if host is None:
        host = 'localhost'

    r = {'user': user,
         'password': passwd,
         'host': host,
         'database': dbname}
    # Check the values: 'in' applied to the dict itself looks at the keys.
    assert None not in r.values(), r
    return r

def isdb(deststr):
    "Tell whether the destination string looks like a database connection URI."
    if not deststr:
        # An empty or missing destination is handed back unchanged.
        return deststr
    return re.match('[a-z]+://', deststr) is not None


def main():
    """ Run the command-line program: parse the options, extract the entries
    from the given files, infer the model, then store the schema and/or the
    data at the requested destinations. """
    import optparse
    parser = optparse.OptionParser(__doc__.strip())

    parser.add_option('-s', '--schema', action='store', metavar="OUTPUT",
                      help="Infer the schema definition from the data and "
                      "store it at the given URI.")

    parser.add_option('-d', '--data', action='store', metavar="OUTPUT",
                      help="Store the contents at the given URI.")

    parser.add_option('-o', '--schema-and-data', action='store', metavar="OUTPUT",
                      help="Both infer the schema definition and store the "
                      "data at the given URI (equivalent to -s and -d).")

    parser.add_option('-r', '--raw', action='store_true',
                      help="Do not sanitize the values, column names, etc.")

    parser.add_option('-v', '--verbose', action='store_true',
                      help="Turn on verbose output.")

    parser.add_option('-e', '--enclosing-section', action='store_true',
                      help="Add a field for enclosing sections to each entry.")

    # The options are global: the parsing code reads them as well.
    global opts
    opts, args = parser.parse_args()
    if not args:
        parser.error("You must specify a list of filenames to process.")

    logging.basicConfig(level=logging.INFO if opts.verbose else logging.WARNING,
                        format='%(levelname)-8s: %(message)s')

    if opts.schema_and_data:
        opts.schema = opts.data = opts.schema_and_data
        opts.schema_and_data = None
    if not opts.schema and not opts.data:
        opts.data = '-'

    #---------------------------------------------------------------------------
    logging.info("Processing input files, extract all the data entries.")

    # Disable the conversion of system messages into text.
    nodes.system_message.astext = lambda *args: u''

    entries_by_document = {}
    entries_list = []
    for fn in args:
        logging.info("  %s" % fn)

        docid, entries = get_file_entries(fn)

        entries_by_document[docid] = entries
        entries_list.extend(entries)

    if not entries_list:
        logging.info("No entries, exiting.")
        sys.exit(1)

    # Make sure that the values have unique key names.
    uniquify_keys(entries_list)

    #---------------------------------------------------------------------------
    logging.info("Inferring model (and read target model if necessary).")

    infmodel = infer_tables(entries_list)

    # If relevant, we read the schema of the target database in order to skew
    # the sanitization process for a better fit to it.
    if opts.data and isdb(opts.data):
        conn = dbapi.connect(**parse_dburi(opts.data))
        try:
            dbmodel = db_get_model(conn)
        finally:
            conn.close()
    else:
        dbmodel = None

    # Sanitize the model.
    if not opts.raw:
        infmodel, sani_info = sanitize_model(
            infmodel, dbmodel if not opts.schema else None)

    #---------------------------------------------------------------------------
    logging.info("Storing the schema.")

    if opts.schema:

        oss = StringIO()
        for table, tabledef in infmodel.iteritems():
            oss.write(table2sql(table, tabledef))
            oss.write('\n')
        sqlcreate = oss.getvalue()

        if isdb(opts.schema):
            # Drop the tables and apply to the given database.
            conn = dbapi.connect(**parse_dburi(opts.schema))
            try:
                curs = conn.cursor()
                for table, _ in infmodel.iteritems():
                    try:
                        curs.execute("DROP TABLE %s" % table)
                    except dbapi.Error:
                        conn.rollback()
                        pass # Ignore "table does not exist" errors.
                    else:
                        conn.commit()

                curs.execute(sqlcreate)
                conn.commit()
            finally:
                conn.close()
        else:
            # Write the table creation to a file.
            f = open(opts.schema, 'w')
            try:
                f.write(sqlcreate)
            finally:
                f.close()

    #---------------------------------------------------------------------------
    logging.info("Storing the data.")

    if not opts.raw:
        logging.info("  ... sanitizing the data")
        sanitize_data(entries_list, sani_info)

    if opts.data:
        if isdb(opts.data):
            logging.info("  ... writing to database")

            # Open a connection to the database.
            conn = dbapi.connect(**parse_dburi(opts.data))
            try:
                # Open the database and inspect the model.
                dbmodel = db_get_model(conn)

                # Store the entries as much as we can.
                curs = conn.cursor()
                store_entries(entries_list, dbmodel, curs)
                conn.commit()
            finally:
                conn.close()

        else:
            # Write the data to a CSV file.
            if opts.data == '-':
                f = sys.stdout
            else:
                f = open(opts.data, 'w')

            try:
                f.write("DATA")
                ## FIXME TODO, this should be easy
            finally:
                # Only close what was opened here, never the standard output.
                if f is not sys.stdout:
                    f.close()



if __name__ == '__main__':
    # Run the command-line program when this file is executed as a script.
    main()







## FIXME: add an option to force an additional field for all the files processed.

## FIXME: fields with * at the end (e.g. comments*) should not get joined/sanitized.

## FIXME: How do we deal with subtypes?  e.g. :p/cell:


